{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SEGAN_OM\n",
    "\n",
    "A GAN-based filtering method for speech enhancement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from io import *\n",
    "import os.path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Audio framing\n",
    "sampling_rate = 16000  # sample rate every wav file is checked against\n",
    "SAMPLING = tf.constant(sampling_rate, dtype=tf.int32, shape=())\n",
    "WINDOWS_SIZE = 2 ** 14  # samples per window\n",
    "STRIDE = 0.5  # hop between consecutive windows, as a fraction of WINDOWS_SIZE\n",
    "\n",
    "# Model / training\n",
    "KERNEL_SIZE = 31\n",
    "BATCH_SIZE = 10  # used for loading the data\n",
    "EPOCHS = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1) Data pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folders holding the paired recordings. os.path.join picks the separator of the\n",
    "# host OS, so the notebook is no longer tied to Windows-style paths.\n",
    "path = os.path.join(\".\", \"Dataset\", \"clean\")\n",
    "path_noisy = os.path.join(\".\", \"Dataset\", \"noisy\")\n",
    "# Number of entries in the clean folder; shown in the progress message while\n",
    "# the TFRecord file is written.\n",
    "files_number = len(os.listdir(path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The following functions can be used to convert a value to a type compatible\n",
    "# with tf.train.Example.\n",
    "def _bytes_feature(value):\n",
    "    \"\"\"Wrap a string / bytes value in a tf.train.Feature holding a BytesList.\"\"\"\n",
    "    eager_tensor_type = type(tf.constant(0))\n",
    "    # BytesList won't unpack a string from an EagerTensor, so extract it first.\n",
    "    raw = value.numpy() if isinstance(value, eager_tensor_type) else value\n",
    "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))\n",
    "\n",
    "def _float_feature(value):\n",
    "    \"\"\"Wrap a float / double in a tf.train.Feature holding a FloatList.\"\"\"\n",
    "    wrapped = tf.train.FloatList(value=[value])\n",
    "    return tf.train.Feature(float_list=wrapped)\n",
    "\n",
    "def _int64_feature(value):\n",
    "    \"\"\"Wrap a bool / enum / int / uint in a tf.train.Feature holding an Int64List.\"\"\"\n",
    "    wrapped = tf.train.Int64List(value=[value])\n",
    "    return tf.train.Feature(int64_list=wrapped)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def slice_signal(signal, window_size=2**14, stride=0.5):\n",
    "    \"\"\" Cut a signal into fixed-length windows by sweeping in stride fractions\n",
    "        of window.\n",
    "\n",
    "    Arguments:\n",
    "    signal -- array-like whose first axis is time (anything np.asarray accepts)\n",
    "    window_size -- number of samples per window\n",
    "    stride -- hop between consecutive windows, as a fraction of window_size\n",
    "\n",
    "    Returns:\n",
    "    np.float32 array of shape (n_windows, window_size, ...). Trailing samples\n",
    "    that do not fill a whole window are dropped. When no full window fits,\n",
    "    n_windows is 0 but the remaining dimensions are kept.\n",
    "\n",
    "    Raises:\n",
    "    ValueError -- if window_size * stride truncates to less than one sample\n",
    "    \"\"\"\n",
    "    signal = np.asarray(signal, dtype=np.float32)\n",
    "    hop = int(window_size * stride)\n",
    "    if hop < 1:\n",
    "        raise ValueError(\n",
    "            f'window_size * stride must be at least 1 sample, got {window_size} * {stride}')\n",
    "    # Only start positions that leave room for a complete window are visited.\n",
    "    starts = range(0, signal.shape[0] - window_size + 1, hop)\n",
    "    if len(starts) == 0:\n",
    "        # Keep a consistent rank so callers can rely on the output shape.\n",
    "        return np.empty((0, window_size) + signal.shape[1:], dtype=np.float32)\n",
    "    return np.stack([signal[beg_i:beg_i + window_size] for beg_i in starts])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_tf_example(track_1, track_2):\n",
    "    \"\"\"Pack a pair of serialized windows into one tf.train.Example.\n",
    "\n",
    "    track_1 is stored under the key 'clean' and track_2 under the key 'noisy'.\n",
    "    \"\"\"\n",
    "    feature_map = tf.train.Features(feature={\n",
    "        'clean': _bytes_feature(track_1),\n",
    "        'noisy': _bytes_feature(track_2),\n",
    "    })\n",
    "    return tf.train.Example(features=feature_map)\n",
    "\n",
    "#feature = _bytes_feature(clean_list[0][0].tobytes())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "file 40 from 600 written\n"
     ]
    }
   ],
   "source": [
    "# Serialize paired clean / noisy windows into a single TFRecord file.\n",
    "MAX_FILES = 80  # only the first MAX_FILES recordings (in sorted order) are converted\n",
    "record_file = 'small_waves.tfrecords'\n",
    "i = 0  # stays 0 when the clean folder is empty\n",
    "with tf.io.TFRecordWriter(record_file) as writer:\n",
    "    # os.listdir returns entries in arbitrary order; sorting keeps the subset reproducible.\n",
    "    for i, file_name in enumerate(sorted(os.listdir(path))[:MAX_FILES], start=1):\n",
    "        name, _ = os.path.splitext(file_name)\n",
    "\n",
    "        # decode_wav returns the audio tensor and the scalar sample rate of the file.\n",
    "        clean_path = os.path.join(path, file_name)\n",
    "        clean, sample_rate = tf.audio.decode_wav(tf.io.read_file(clean_path))\n",
    "        if not tf.math.equal(SAMPLING,sample_rate):\n",
    "            raise ValueError(f'Sampling rate of clean is expected to be {SAMPLING}! Got {sample_rate}')\n",
    "\n",
    "        # The noisy counterpart shares the base name of the clean recording.\n",
    "        noisy_path = os.path.join(path_noisy, name + '_CAFE_SNR_0.wav')\n",
    "        noisy, sample_rate = tf.audio.decode_wav(tf.io.read_file(noisy_path))\n",
    "        if not tf.math.equal(SAMPLING,sample_rate):\n",
    "            raise ValueError(f'Sampling rate of noisy is expected to be {SAMPLING}! Got {sample_rate}')\n",
    "\n",
    "        seq_clean = slice_signal(clean,WINDOWS_SIZE,STRIDE)\n",
    "        seq_noisy = slice_signal(noisy,WINDOWS_SIZE,STRIDE)\n",
    "        for track_clean, track_noisy in zip(seq_clean, seq_noisy):\n",
    "            # tobytes() replaces the deprecated ndarray.tostring() alias (same bytes).\n",
    "            tf_example = make_tf_example(track_clean.tobytes(), track_noisy.tobytes())\n",
    "            writer.write(tf_example.SerializeToString())\n",
    "        clear_output()\n",
    "        print(f\"file {i} from {files_number} written\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2) Layers & Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.1: Build the Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def downsample(filter_width, kernel=31, strides=2, padding='same', init=None):\n",
    "    \"\"\"\n",
    "    Build a 1D-Conv block (Conv1D followed by PReLU) for the Generator.\n",
    "\n",
    "    Arguments:\n",
    "    filter_width -- number of filters of the Conv1D layer\n",
    "    kernel -- Conv1D kernel_size, 31 by default\n",
    "    strides -- Conv1D strides, 2 by default\n",
    "    padding -- Conv1D padding mode, 'same' by default\n",
    "    init -- kernel initializer; He-normal is used when None is given\n",
    "\n",
    "    Returns:\n",
    "    block -- tf.keras.Sequential holding the bias-free Conv1D and the PReLU\n",
    "    \"\"\"\n",
    "    kernel_init = tf.keras.initializers.he_normal() if init is None else init\n",
    "\n",
    "    conv = tf.keras.layers.Conv1D(filters=filter_width, kernel_size=kernel, strides=strides,\n",
    "                                  padding=padding, kernel_initializer=kernel_init,\n",
    "                                  use_bias=False)\n",
    "    activation = tf.keras.layers.PReLU()\n",
    "    return tf.keras.Sequential([conv, activation])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def upsample(filter_width, kernel=31, strides=2, padding='same', init=None):\n",
    "    \"\"\"\n",
    "    Build a deconvolution block (Conv2DTranspose followed by LeakyReLU) for the Generator.\n",
    "\n",
    "    The layer is created with kernel_size (kernel, 1) and strides (strides, 1),\n",
    "    so it only acts along the first spatial axis of its input.\n",
    "\n",
    "    Arguments:\n",
    "    filter_width -- number of filters of the Conv2DTranspose layer\n",
    "    kernel -- kernel length along the first spatial axis, 31 by default\n",
    "    strides -- stride along the first spatial axis, 2 by default\n",
    "    padding -- padding mode, 'same' by default\n",
    "    init -- kernel initializer; He-normal is used when None is given\n",
    "\n",
    "    Returns:\n",
    "    block -- tf.keras.Sequential holding the bias-free Conv2DTranspose and the LeakyReLU\n",
    "    \"\"\"\n",
    "    kernel_init = tf.keras.initializers.he_normal() if init is None else init\n",
    "\n",
    "    deconv = tf.keras.layers.Conv2DTranspose(filters=filter_width, kernel_size=(kernel, 1),\n",
    "                                             strides=(strides, 1), padding=padding,\n",
    "                                             kernel_initializer=kernel_init, use_bias=False)\n",
    "    activation = tf.keras.layers.LeakyReLU()\n",
    "    return tf.keras.Sequential([deconv, activation])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Generator(input_length=WINDOWS_SIZE, kernel=KERNEL_SIZE):\n",
    "    \"\"\"\n",
    "    Build the generator: a strided-convolution encoder followed by a\n",
    "    transposed-convolution decoder with skip connections.\n",
    "\n",
    "    Arguments:\n",
    "    input_length -- samples per input window; must be divisible by 16 because\n",
    "                    of the four stride-2 stages (defaults to WINDOWS_SIZE)\n",
    "    kernel -- kernel size used by every block (defaults to KERNEL_SIZE)\n",
    "\n",
    "    Returns:\n",
    "    tf.keras.Model mapping [input_length, 1] inputs to [input_length, 1] outputs\n",
    "    \"\"\"\n",
    "    def apply_up(block, tensor):\n",
    "        # upsample() blocks are built on Conv2DTranspose, so a dummy axis is\n",
    "        # inserted for the call and removed again afterwards.\n",
    "        tensor = tf.keras.backend.expand_dims(tensor, axis=2)\n",
    "        tensor = block(tensor)\n",
    "        return tf.keras.backend.squeeze(tensor, axis=2)\n",
    "\n",
    "    inputs = tf.keras.layers.Input(shape=[input_length, 1])\n",
    "\n",
    "    # Filters per stage. The second argument of downsample/upsample is the\n",
    "    # kernel size, not the feature-map length, so `kernel` is passed here; the\n",
    "    # former values (16384, 8192, ...) created kernels as long as the signal.\n",
    "    # Deeper stages (64, 128, 128, 256, 256, 512, 1024 filters) stay disabled.\n",
    "    down_stack = [downsample(filters, kernel) for filters in (16, 32, 32, 64)]\n",
    "    up_stack = [upsample(filters, kernel) for filters in (32, 32, 16)]\n",
    "\n",
    "    # Downsampling through the model\n",
    "    x = inputs\n",
    "    skips = []\n",
    "    for down in down_stack:\n",
    "        x = down(x)\n",
    "        skips.append(x)\n",
    "\n",
    "    # The last activation is the bottleneck itself; the others are reused,\n",
    "    # deepest first, as skip connections.\n",
    "    skips = reversed(skips[:-1])\n",
    "\n",
    "    # Upsampling and establishing the skip connections\n",
    "    for up, skip in zip(up_stack, skips):\n",
    "        x = tf.keras.layers.Concatenate()([apply_up(up, x), skip])\n",
    "\n",
    "    # The final block restores the full length with a single channel.\n",
    "    x = apply_up(upsample(1, kernel), x)\n",
    "\n",
    "    return tf.keras.Model(inputs=inputs, outputs=x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = Generator()\n",
    "generator.summary()\n",
    "#tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.2 Build Discriminator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3) Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_waves_dataset = tf.data.TFRecordDataset('small_waves.tfrecords',num_parallel_reads= tf.data.experimental.AUTOTUNE)\n",
    "raw_waves_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_wave_function(example_proto):\n",
    "    \"\"\"Parse one serialized tf.train.Example into a dict holding the scalar\n",
    "    string features 'clean' and 'noisy'.\"\"\"\n",
    "    scalar_string = tf.io.FixedLenFeature([], tf.string)\n",
    "    feature_description = {key: scalar_string for key in ('clean', 'noisy')}\n",
    "    return tf.io.parse_single_example(example_proto, feature_description)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _decode_parsed_wave(parsed_wave):\n",
    "    \"\"\"Decode the raw 'clean' and 'noisy' byte strings of a parsed example\n",
    "    as float32 and return them as the pair (clean, noisy).\"\"\"\n",
    "    clean = tf.io.decode_raw(parsed_wave['clean'], tf.float32)\n",
    "    noisy = tf.io.decode_raw(parsed_wave['noisy'], tf.float32)\n",
    "    return clean, noisy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _correct_dim_wave(dec_clean, dec_noisy):\n",
    "    \"\"\"Insert an axis of size 1 at position 1 of both waves.\"\"\"\n",
    "    clean_column = tf.expand_dims(dec_clean, axis=1)\n",
    "    noisy_column = tf.expand_dims(dec_noisy, axis=1)\n",
    "    return clean_column, noisy_column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parsed_waves_dataset = raw_waves_dataset.map(_parse_wave_function)\n",
    "parsed_waves_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "decoded_waves_dataset = parsed_waves_dataset.map(_decode_parsed_wave)\n",
    "decoded_waves_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corrected_waves_dataset = decoded_waves_dataset.map(_correct_dim_wave)\n",
    "corrected_waves_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the shuffle order. With the default reshuffle_each_iteration=True the\n",
    "# order is redrawn on every pass, so a take()/skip() split of this dataset would\n",
    "# hand different windows to training and testing each time (data leakage).\n",
    "dataset = corrected_waves_dataset.shuffle(200, seed=0, reshuffle_each_iteration=False)\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_set = dataset.take(200) #2677-1870\n",
    "testing_set  = dataset.skip(200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch both splits with the configured BATCH_SIZE instead of a repeated literal;\n",
    "# incomplete trailing batches are dropped.\n",
    "training_set = training_set.batch(BATCH_SIZE, drop_remainder=True)\n",
    "testing_set  = testing_set.batch(BATCH_SIZE, drop_remainder=True)\n",
    "training_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "testing_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count the batches of the testing split lazily; unlike indexing the last\n",
    "# element of an enumerated list, this also works (and yields 0) when it is empty.\n",
    "dataset_length = sum(1 for _ in testing_set)\n",
    "print(dataset_length)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4) Train the model\n",
    "\n",
    "In order to understand the functioning of the generator, we will train it over 50 epochs using the L1 loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generator_loss(gen_output, target):\n",
    "    \"\"\"Mean absolute error (L1 loss) between the target and the generator output.\"\"\"\n",
    "    absolute_error = tf.abs(target - gen_output)\n",
    "    return tf.reduce_mean(absolute_error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator_optimizer = tf.keras.optimizers.Adam()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_dir = './training_checkpoints'\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
    "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n",
    "                                 generator=generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "log_dir=\"logs/\"\n",
    "\n",
    "summary_writer = tf.summary.create_file_writer(log_dir + \"fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(input_wave, target, epoch):\n",
    "    \"\"\"Run one optimization step of the generator on a batch.\n",
    "\n",
    "    Arguments:\n",
    "    input_wave -- batch fed to the generator\n",
    "    target -- batch the generator output is compared against\n",
    "    epoch -- step value under which the loss is written to the summary log\n",
    "\n",
    "    Returns:\n",
    "    gen_l1_loss -- scalar L1 loss of this batch\n",
    "    \"\"\"\n",
    "    with tf.GradientTape() as gen_tape:\n",
    "        gen_output = generator(input_wave, training=True)\n",
    "        gen_l1_loss = generator_loss(gen_output, target)\n",
    "\n",
    "    generator_gradients = gen_tape.gradient(gen_l1_loss, generator.trainable_variables)\n",
    "    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))\n",
    "\n",
    "    # This block was indented one level too deep, which made the cell fail\n",
    "    # with an IndentationError. The cast accepts Python ints and integer tensors.\n",
    "    with summary_writer.as_default():\n",
    "        tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=tf.cast(epoch, tf.int64))\n",
    "\n",
    "    return gen_l1_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit(train_ds, epochs, test_ds):\n",
    "    \"\"\"Train the generator and save one checkpoint per epoch.\n",
    "\n",
    "    Arguments:\n",
    "    train_ds -- dataset yielding (target, input_wave) batches\n",
    "    epochs -- number of passes over train_ds\n",
    "    test_ds -- currently unused; kept so existing calls keep working\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        start = time.time()\n",
    "        batch_start = start\n",
    "        # Passing the step as a tensor keeps the tf.function from being\n",
    "        # re-traced for every new Python integer.\n",
    "        step = tf.constant(epoch, dtype=tf.int64)\n",
    "        for n, (target, input_wave) in enumerate(train_ds):\n",
    "            print(f'epoch {epoch}, batch {n}')\n",
    "            train_step(input_wave, target, step)\n",
    "            now = time.time()\n",
    "            print(now - batch_start)\n",
    "            batch_start = now\n",
    "\n",
    "        # saving (checkpoint) the model every epoch -- once, not twice\n",
    "        checkpoint.save(file_prefix = checkpoint_prefix)\n",
    "        print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1, time.time()-start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext tensorboard\n",
    "%tensorboard --logdir {log_dir}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for the configured number of epochs (EPOCHS) rather than a separate literal,\n",
    "# so the run matches the epoch count stated in the section text.\n",
    "fit(training_set, EPOCHS, testing_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 11924 is a process id hard-coded from one particular session\n",
    "# (presumably the TensorBoard server started above -- confirm). Look up the\n",
    "# current id before running this cell.\n",
    "!kill 11924"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
